#  -*- coding: utf-8 -*-
#  Copyright (c) 2024-2024 Huawei Technologies Co., Ltd.
#  #
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#  #
#  http://www.apache.org/licenses/LICENSE-2.0
#  #
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.


from typing import Callable, Any, Literal, Annotated, Optional, List, Dict

import torch
import torch.distributed as dist
import torch.nn as nn
from pydantic import AfterValidator, Field

from msmodelslim.core.QAL.qregistry import QABCRegistry
from msmodelslim.quant.ir.norm_bias import RMSNormBias
from msmodelslim.quant.processor.anti_outlier.common import IterSmoothConfig, IterSmoothContext
from msmodelslim.quant.processor.anti_outlier.api.smooth_api import iter_smooth
from msmodelslim.core.base.protocol import BatchProcessRequest
from msmodelslim.quant.processor.anti_outlier.smooth_base import BaseSmoothProcessor, BaseSmoothProcessorConfig
from msmodelslim.quant.processor.anti_outlier.common.smooth_components import StatsCollector, SubgraphRegistry, StatKey
from msmodelslim.quant.processor.base import AutoSessionProcessor
from msmodelslim.quant.observer import MsMinMaxObserver, MinMaxObserverConfig
from msmodelslim.utils.logging import get_logger, logger_setter
from msmodelslim.utils.validation.value import validate_normalized_value, is_boolean, is_string_list


class IterSmoothProcessorConfig(BaseSmoothProcessorConfig):
    type: Literal["iter_smooth"] = "iter_smooth"
    alpha: Annotated[float, AfterValidator(validate_normalized_value)] = 0.9
    scale_min: Annotated[float, AfterValidator(validate_normalized_value)] = 1e-5
    symmetric: Annotated[bool, AfterValidator(is_boolean)] = True
    enable_subgraph_type: Annotated[list, AfterValidator(is_string_list)] = Field(
        default_factory=lambda: ["norm-linear", "linear-linear", "ov", "up-down"]
    )


class IterStatsCollector(StatsCollector):
    ASYM_SUPPORT_SUBGRAPH_TYPES = ["norm-linear"]

    def __init__(self, symmetric: bool):
        super().__init__()
        self.symmetric = symmetric
        self.minmax_observers: Dict[str, MsMinMaxObserver] = {}
        self.channel_max_observers: Dict[str, MsMinMaxObserver] = {}

    def create_hook(self, name: str, subgraph_type: str = None) -> Callable:
        def stats_hook(module: nn.Linear, input_tensor: tuple, output: Any) -> None:
            if not input_tensor or not isinstance(input_tensor, tuple):
                get_logger().warning(f"Input tensor is empty for module {name}")
                return

            tensor = input_tensor[0]
            if name not in self.act_stats:
                self.act_stats[name] = {}
                self.act_stats[name][StatKey.TENSOR] = tensor.cpu()

            hidden_dim = tensor.shape[-1]
            tensor = tensor.reshape(-1, hidden_dim).detach()

            statis_dict = self.act_stats[name]

            if name not in self.minmax_observers:
                observer_config = MinMaxObserverConfig(dim=0, keepdim=False)
                self.minmax_observers[name] = MsMinMaxObserver(observer_config)

            self.minmax_observers[name].update(tensor)
            coming_min, coming_max = self.minmax_observers[name].get_min_max()

            statis_dict[StatKey.STAT_KEY_MAX] = coming_max
            statis_dict[StatKey.STAT_KEY_MIN] = coming_min

            statis_dict[StatKey.STAT_KEY_SHIFT] = (coming_max + coming_min) / 2

            if name not in self.channel_max_observers:
                observer_config = MinMaxObserverConfig(dim=0, keepdim=False)
                self.channel_max_observers[name] = MsMinMaxObserver(observer_config)

            # 根据symmetric/asymmetric模式计算channel_max
            if not self.symmetric and subgraph_type in self.ASYM_SUPPORT_SUBGRAPH_TYPES:
                # asymmetric模式：计算shift后的绝对值最大值
                shifted_tensor = (tensor - statis_dict[StatKey.STAT_KEY_SHIFT]).abs()
                self.channel_max_observers[name].update(shifted_tensor)
            else:
                # symmetric模式：计算绝对值最大值
                abs_tensor = tensor.abs()
                self.channel_max_observers[name].update(abs_tensor)

            _, channel_max = self.channel_max_observers[name].get_min_max()
            statis_dict[StatKey.STAT_KEY_SMOOTH_SCALE] = channel_max

        return stats_hook

    def clear_stats(self) -> None:
        """清除统计信息和observer"""
        super().clear_stats()
        # 重置所有observer
        for observer in self.minmax_observers.values():
            observer.reset()
        for observer in self.channel_max_observers.values():
            observer.reset()
        self.minmax_observers.clear()
        self.channel_max_observers.clear()


@QABCRegistry.register(dispatch_key=IterSmoothProcessorConfig, abc_class=AutoSessionProcessor)
@logger_setter()
class IterSmoothProcessor(BaseSmoothProcessor):

    def __init__(self, model: nn.Module, config: IterSmoothProcessorConfig, adapter: object, **kwargs):
        super().__init__(model, config, adapter)
        self.config = config
        self._validate_parameters()
        self.stats_collector = IterStatsCollector(symmetric=config.symmetric)

        if not config.symmetric:
            supported_types = ', '.join(IterStatsCollector.ASYM_SUPPORT_SUBGRAPH_TYPES)
            get_logger().warning(
                f"Detected asymmetric IterSmooth; currently only supports {supported_types} subgraph types"
            )

    def support_distributed(self) -> bool:
        return True

    def apply_smooth_algorithm(self, subgraph_obj: Any, linear_names: List[str]) -> None:
        subgraph_type = SubgraphRegistry.get_name(type(subgraph_obj))
        if subgraph_type not in IterStatsCollector.ASYM_SUPPORT_SUBGRAPH_TYPES:
            shift_value = False
            get_logger().debug("Non-asym subgraph (%s), setting shift=False", subgraph_type)
        else:
            shift_value = not self.config.symmetric
            get_logger().debug("Asym-capable subgraph (%s), setting shift=%s", subgraph_type, shift_value)

        iter_smooth_cfg = IterSmoothConfig(
            alpha=self.config.alpha,
            shift=shift_value,
            scale_min=self.config.scale_min
        )
        smooth_context = self._build_smooth_context(linear_names)
        if smooth_context is None:
            get_logger().warning(
                "No statistics collected for %s subgraph, skipping. "
                "This may happen for unused MOE experts.",
                subgraph_type
            )
            return

        iter_smooth(subgraph_obj, iter_smooth_cfg, smooth_context)
        get_logger().info(
            "Successfully applied IterSmooth to %s subgraph (shift=%s)", subgraph_type, shift_value
        )

    def preprocess(self, request: BatchProcessRequest) -> None:
        super().preprocess(request)
        self._replace_norm_modules()
        get_logger().debug("Processed %d subgraphs for submodule %s", len(self.adapter_config), request.name)

    def _build_smooth_context(self, linear_names: List[str]) -> Optional[IterSmoothContext]:
        a_smooth_scale = None
        shift = None

        if not linear_names:
            get_logger().warning(
                "No linear modules provided while building IterSmoothContext; skipping smooth application."
            )
            return None
        # 仅用第一个linear的激活统计信息
        linear_name = linear_names[0]

        # 获取激活统计信息
        if linear_name in self.stats_collector.act_stats:
            stats = self.stats_collector.act_stats[linear_name]

            # 获取 smooth_scale
            if StatKey.STAT_KEY_SMOOTH_SCALE in stats:
                a_smooth_scale = stats[StatKey.STAT_KEY_SMOOTH_SCALE]
            else:
                a_smooth_scale = None

            # 获取 shift
            if StatKey.STAT_KEY_SHIFT in stats:
                shift = stats[StatKey.STAT_KEY_SHIFT]
            else:
                shift = None
        else:
            get_logger().warning(f"Linear name {linear_name} not in act_stats")
            return None

        # 检查是否成功获取到激活平滑尺度
        if a_smooth_scale is None:
            # 返回 None 而不是抛出异常，让调用者决定如何处理
            get_logger().debug(
                "Failed to get activation smooth scale from linear name {linear_name}. "
                "This may happen for unused subgraphs (e.g., unactivated MOE experts)."
            )
            return None
        # 创建 IterSmoothContext
        smooth_context = IterSmoothContext(
            version=1,
            a_smooth_scale=a_smooth_scale,
            shift=shift
        )

        return smooth_context

    def _replace_norm_modules(self) -> None:
        for adapter_config in self.adapter_config:
            if adapter_config.subgraph_type == "norm-linear":
                norm_name = adapter_config.mapping.source
                norm_module = self.model.get_submodule(norm_name) if norm_name else None
                if norm_name and norm_module is not None:
                    try:
                        if hasattr(norm_module, 'weight'):
                            norm_bias = RMSNormBias(norm_module.weight.shape[-1])
                            norm_bias.weight.data.copy_(norm_module.weight.data)
                            norm_bias.weight.data = norm_bias.weight.data.type(norm_module.weight.data.dtype)
                            if hasattr(norm_module, 'bias') and norm_module.bias is not None:
                                norm_bias.bias.data.copy_(norm_module.bias.data)
                                norm_bias.bias.data = norm_bias.bias.data.type(norm_module.weight.data.dtype)
                            norm_bias.to(norm_module.weight.data.device)
                            self.model.set_submodule(norm_name, norm_bias)
                            get_logger().debug("%s: %s -> %s", norm_name, type(norm_module), type(norm_bias))
                        else:
                            get_logger().warning("Norm module %s does not have weight attribute", norm_name)
                    except Exception as e:
                        get_logger().warning("Failed to replace norm module %s: %s", norm_name, e)
